# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import json
import warnings
from typing import Any, Dict, List

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses.chat_message import ChatMessage, ToolCall
from haystack.dataclasses.tool import Tool, ToolInvocationError, _check_duplicate_tool_names, deserialize_tools_inplace

logger = logging.getLogger(__name__)

# Error-message templates used by ToolInvoker; each is filled in with `str.format`.
# Failure raised/reported while a tool is being invoked.
_TOOL_INVOCATION_FAILURE = "Tool invocation failed with error: {error}."
# A tool call names a tool that the invoker was not configured with.
_TOOL_NOT_FOUND = (
    "Tool {tool_name} not found in the list of tools. "
    "Available tools are: {available_tools}."
)
# The tool ran, but its result could not be turned into a string.
_TOOL_RESULT_CONVERSION_FAILURE = (
    "Failed to convert tool result to string using '{conversion_function}'. "
    "Error: {error}."
)


class ToolNotFoundException(Exception):
    """
    Raised when a tool call names a tool that is not among the tools the invoker was configured with.
    """


class StringConversionError(Exception):
    """
    Raised when a tool's result cannot be turned into a string.
    """


@component
class ToolInvoker:
    """
    Invokes tools based on prepared tool calls and returns the results as a list of ChatMessage objects.

    At initialization, the ToolInvoker component is provided with a list of available tools.
    At runtime, the component processes a list of ChatMessage objects containing tool calls
    and invokes the corresponding tools.
    The results of the tool invocations are returned as a list of ChatMessage objects with tool role.

    Usage example:
    ```python
    from haystack.dataclasses import ChatMessage, ToolCall, Tool
    from haystack.components.tools import ToolInvoker

    # Tool definition
    def dummy_weather_function(city: str):
        return f"The weather in {city} is 20 degrees."

    parameters = {"type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"]}

    tool = Tool(name="weather_tool",
                description="A tool to get the weather",
                function=dummy_weather_function,
                parameters=parameters)

    # Usually, the ChatMessage with tool_calls is generated by a Language Model
    # Here, we create it manually for demonstration purposes
    tool_call = ToolCall(
        tool_name="weather_tool",
        arguments={"city": "Berlin"}
    )
    message = ChatMessage.from_assistant(tool_calls=[tool_call])

    # ToolInvoker initialization and run
    invoker = ToolInvoker(tools=[tool])
    result = invoker.run(messages=[message])

    print(result)
    ```

    With the default `convert_result_to_json_string=False`, the result is converted with `str`
    (with `True` it would be JSON-encoded, i.e. wrapped in double quotes):

    ```
    >>  {
    >>      'tool_messages': [
    >>          ChatMessage(
    >>              _role=<ChatRole.TOOL: 'tool'>,
    >>              _content=[
    >>                  ToolCallResult(
    >>                      result='The weather in Berlin is 20 degrees.',
    >>                      origin=ToolCall(
    >>                          tool_name='weather_tool',
    >>                          arguments={'city': 'Berlin'},
    >>                          id=None
    >>                      )
    >>                  )
    >>              ],
    >>              _meta={}
    >>          )
    >>      ]
    >>  }
    ```
    """

    def __init__(self, tools: List[Tool], raise_on_failure: bool = True, convert_result_to_json_string: bool = False):
        """
        Initialize the ToolInvoker component.

        :param tools:
            A list of tools that can be invoked.
        :param raise_on_failure:
            If True, the component will raise an exception in case of errors
            (tool not found, tool invocation errors, tool result conversion errors).
            If False, the component will return a ChatMessage object with `error=True`
            and a description of the error in `result`.
        :param convert_result_to_json_string:
            If True, the tool invocation result will be converted to a string using `json.dumps`.
            If False, the tool invocation result will be converted to a string using `str`.

        :raises ValueError:
            If no tools are provided or if duplicate tool names are found.
        """

        msg = "The `ToolInvoker` component is experimental and its API may change in the future."
        warnings.warn(msg)

        if not tools:
            raise ValueError("ToolInvoker requires at least one tool to be provided.")
        _check_duplicate_tool_names(tools)

        self.tools = tools
        # Name -> Tool lookup used by `run`; names are unique after the duplicate check above.
        self._tools_with_names = {tool.name: tool for tool in tools}
        self.raise_on_failure = raise_on_failure
        self.convert_result_to_json_string = convert_result_to_json_string

    def _prepare_tool_result_message(self, result: Any, tool_call: ToolCall) -> ChatMessage:
        """
        Converts a tool result to a string and wraps it in a ChatMessage with tool role.

        The conversion uses `json.dumps` when `convert_result_to_json_string` is True, otherwise `str`.

        :param result:
            The value returned by the tool.
        :param tool_call:
            The ToolCall that produced `result`; it is attached to the message as its origin.
        :returns:
            A ChatMessage containing the tool result as a string. If the conversion fails and
            `raise_on_failure` is False, the message instead contains a description of the error
            and is flagged with `error=True`.
        :raises StringConversionError:
            If the conversion of the tool result to a string fails and `raise_on_failure` is True.
        """
        conversion_function = "json.dumps" if self.convert_result_to_json_string else "str"
        error = False

        try:
            if self.convert_result_to_json_string:
                # ensure_ascii=False keeps non-ASCII characters such as emojis as-is instead of \uXXXX escapes
                tool_result_str = json.dumps(result, ensure_ascii=False)
            else:
                tool_result_str = str(result)
        except Exception as e:
            # Both json.dumps and an arbitrary __str__ may raise many exception types, so catch broadly here.
            if self.raise_on_failure:
                raise StringConversionError(
                    f"Failed to convert tool result to string using `{conversion_function}`"
                ) from e
            tool_result_str = _TOOL_RESULT_CONVERSION_FAILURE.format(error=e, conversion_function=conversion_function)
            error = True

        return ChatMessage.from_tool(tool_result=tool_result_str, error=error, origin=tool_call)

    @component.output_types(tool_messages=List[ChatMessage])
    def run(self, messages: List[ChatMessage]) -> Dict[str, Any]:
        """
        Processes ChatMessage objects containing tool calls and invokes the corresponding tools, if available.

        :param messages:
            A list of ChatMessage objects.
        :returns:
            A dictionary with the key `tool_messages` containing a list of ChatMessage objects with tool role.
            Each ChatMessage object wraps the result of a tool invocation.

        :raises ToolNotFoundException:
            If the tool is not found in the list of available tools and `raise_on_failure` is True.
        :raises ToolInvocationError:
            If the tool invocation fails and `raise_on_failure` is True.
        :raises StringConversionError:
            If the conversion of the tool result to a string fails and `raise_on_failure` is True.
        """
        tool_messages = []

        for message in messages:
            tool_calls = message.tool_calls
            if not tool_calls:
                # Messages without tool calls produce no output.
                continue

            for tool_call in tool_calls:
                tool_name = tool_call.tool_name

                if tool_name not in self._tools_with_names:
                    # list(...) renders as ['a', 'b'] instead of dict_keys(['a', 'b']) in the message.
                    msg = _TOOL_NOT_FOUND.format(tool_name=tool_name, available_tools=list(self._tools_with_names))
                    if self.raise_on_failure:
                        raise ToolNotFoundException(msg)
                    tool_messages.append(ChatMessage.from_tool(tool_result=msg, origin=tool_call, error=True))
                    continue

                tool_to_invoke = self._tools_with_names[tool_name]
                try:
                    tool_result = tool_to_invoke.invoke(**tool_call.arguments)
                except ToolInvocationError as e:
                    if self.raise_on_failure:
                        # Bare `raise` re-raises the active exception without adding an extra traceback entry.
                        raise
                    msg = _TOOL_INVOCATION_FAILURE.format(error=e)
                    tool_messages.append(ChatMessage.from_tool(tool_result=msg, origin=tool_call, error=True))
                    continue

                tool_messages.append(self._prepare_tool_result_message(tool_result, tool_call))

        return {"tool_messages": tool_messages}

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        # Each tool is serialized individually so `from_dict` can rebuild it.
        serialized_tools = [tool.to_dict() for tool in self.tools]
        return default_to_dict(
            self,
            tools=serialized_tools,
            raise_on_failure=self.raise_on_failure,
            convert_result_to_json_string=self.convert_result_to_json_string,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ToolInvoker":
        """
        Deserializes the component from a dictionary.

        :param data:
            The dictionary to deserialize from.
        :returns:
            The deserialized component.
        """
        # NOTE(review): presumably turns the serialized dicts under data["init_parameters"]["tools"]
        # back into Tool objects, mutating `data` in place — confirm in haystack.dataclasses.tool.
        deserialize_tools_inplace(data["init_parameters"], key="tools")
        return default_from_dict(cls, data)
